{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "edcbf738",
   "metadata": {},
   "source": [
    "## Pixtral 12b LMI v12 Deployment Guide"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b69ae92d",
   "metadata": {},
   "source": [
    "This notebook demonstrates how to deploy the [Pixtral-12b](https://huggingface.co/mistralai/Pixtral-12B-2409) model using the LMI v12 container. The model is not stored in the typical HuggingFace pretrained format, so additional configuration is required to deploy it successfully. Community versions of this model that have been converted into the HuggingFace pretrained format exist, but they are not compatible with vLLM and therefore cannot be used with LMI v12.\n",
    "\n",
    "If you have fine-tuned this model and saved the artifacts in the HuggingFace pretrained format, you will need to convert the artifacts back into the Mistral format before deploying. You can read more about that process in this discussion: https://huggingface.co/mistral-community/pixtral-12b/discussions/4."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60b26e58",
   "metadata": {},
   "source": [
    "## Install required dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b94d69b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install sagemaker boto3 pillow requests"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbb8969a",
   "metadata": {},
   "source": [
    "## Create the SageMaker model object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62631e4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker import image_uris\n",
    "from sagemaker.djl_inference import DJLModel\n",
    "\n",
    "image_uri = \"763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.30.0-lmi12.0.0-cu124\"\n",
    "\n",
    "role = sagemaker.get_execution_role()\n",
    "\n",
    "# The image URI above is pinned to us-west-2; use the registry for the region you deploy in.\n",
    "# Once the SageMaker Python SDK PR is merged, the URI can be retrieved with image_uris instead:\n",
    "# image_uri = image_uris.retrieve(framework=\"djl-lmi\", version=\"0.30.0\", region=\"us-west-2\")\n",
    "\n",
    "model = DJLModel(\n",
    "    role=role,\n",
    "    image_uri=image_uri,\n",
    "    env={\n",
    "        \"HF_MODEL_ID\": \"mistralai/Pixtral-12B-2409\",\n",
    "        \"HF_TOKEN\": \"<huggingface hub token>\",\n",
    "        \"OPTION_ENGINE\": \"Python\",\n",
    "        \"OPTION_MPI_MODE\": \"true\",\n",
    "        \"OPTION_ROLLING_BATCH\": \"lmi-dist\",\n",
    "        \"OPTION_MAX_MODEL_LEN\": \"8192\", # this can be tuned depending on instance type + memory available\n",
    "        \"OPTION_MAX_ROLLING_BATCH_SIZE\": \"16\", # this can be tuned depending on instance type + memory available\n",
    "        \"OPTION_TOKENIZER_MODE\": \"mistral\",\n",
    "        \"OPTION_ENTRYPOINT\": \"djl_python.huggingface\",\n",
    "        \"OPTION_TENSOR_PARALLEL_DEGREE\": \"max\",\n",
    "        \"OPTION_LIMIT_MM_PER_PROMPT\": \"image=4\", # this can be tuned to control how many images per prompt are allowed\n",
    "    }\n",
    ")"
   ]
  },
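  {
   "cell_type": "markdown",
   "id": "3f9a1c7e",
   "metadata": {},
   "source": [
    "The cell above passes the Hugging Face token as a placeholder string. As a minimal sketch, assuming the token has been exported in an `HF_TOKEN` environment variable on the notebook instance (an assumption, not something this guide sets up), the same container settings could be assembled by a small helper. The name `build_lmi_env` and its parameters are hypothetical; the keys and default values mirror the configuration used above.\n",
    "\n",
    "```python\n",
    "import os\n",
    "\n",
    "\n",
    "def build_lmi_env(hf_token, max_model_len=8192, max_batch_size=16, max_images=4):\n",
    "    # Hypothetical helper: keys and defaults mirror the DJLModel cell above.\n",
    "    if not hf_token or hf_token.startswith(\"<\"):\n",
    "        raise ValueError(\"Set a real Hugging Face token, not a placeholder\")\n",
    "    return {\n",
    "        \"HF_MODEL_ID\": \"mistralai/Pixtral-12B-2409\",\n",
    "        \"HF_TOKEN\": hf_token,\n",
    "        \"OPTION_ENGINE\": \"Python\",\n",
    "        \"OPTION_MPI_MODE\": \"true\",\n",
    "        \"OPTION_ROLLING_BATCH\": \"lmi-dist\",\n",
    "        # Environment variable values are passed to the container as strings\n",
    "        \"OPTION_MAX_MODEL_LEN\": str(max_model_len),\n",
    "        \"OPTION_MAX_ROLLING_BATCH_SIZE\": str(max_batch_size),\n",
    "        \"OPTION_TOKENIZER_MODE\": \"mistral\",\n",
    "        \"OPTION_ENTRYPOINT\": \"djl_python.huggingface\",\n",
    "        \"OPTION_TENSOR_PARALLEL_DEGREE\": \"max\",\n",
    "        \"OPTION_LIMIT_MM_PER_PROMPT\": f\"image={max_images}\",\n",
    "    }\n",
    "\n",
    "\n",
    "# Assumption: HF_TOKEN was exported in the notebook's environment beforehand.\n",
    "token = os.environ.get(\"HF_TOKEN\")\n",
    "if token:\n",
    "    env = build_lmi_env(token)  # pass as DJLModel(role=role, image_uri=image_uri, env=env)\n",
    "```\n",
    "\n",
    "Keeping the tunable values as function parameters makes it easier to try different context lengths, batch sizes, or image limits per instance type without editing the dictionary by hand."
   ]
  },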
  {
   "cell_type": "markdown",
   "id": "57c5e907",
   "metadata": {},
   "source": [
    "## Deploy the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f022cae4",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor = model.deploy(instance_type=\"ml.g6.12xlarge\", initial_instance_count=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25a366fe",
   "metadata": {},
   "source": [
    "## Test prompts"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b9bc8cb",
   "metadata": {},
   "source": [
    "The following prompts demonstrate how to use the Pixtral-12b model for:\n",
    "- Text-only inference\n",
    "- Single-image inference\n",
    "- Multi-image inference\n",
    "\n",
    "The multi-image example uses two images, but the model as configured here accepts up to 4 images in a single prompt. This limit can be tuned with the `OPTION_LIMIT_MM_PER_PROMPT` configuration."
   ]
  },
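  {
   "cell_type": "markdown",
   "id": "8d2e4b60",
   "metadata": {},
   "source": [
    "The three request types listed above differ only in the `content` field: a plain string for text-only requests, and a list of `text` and `image_url` parts when images are attached. The following is a minimal sketch of a helper that assembles such payloads and rejects requests that exceed the image limit. The name `build_chat_payload` and its parameters are hypothetical, and `max_images` is assumed to match the `image=N` value set in `OPTION_LIMIT_MM_PER_PROMPT`.\n",
    "\n",
    "```python\n",
    "def build_chat_payload(prompt, image_urls=(), max_images=4, max_tokens=1024):\n",
    "    # Hypothetical helper: assembles the chat payload shape used in this notebook.\n",
    "    # Assumption: max_images matches the image=N value in OPTION_LIMIT_MM_PER_PROMPT.\n",
    "    if len(image_urls) > max_images:\n",
    "        raise ValueError(f\"Got {len(image_urls)} images, but the limit is {max_images}\")\n",
    "    if not image_urls:\n",
    "        content = prompt  # text-only requests send a plain string\n",
    "    else:\n",
    "        # One text part followed by one image_url part per image\n",
    "        content = [{\"type\": \"text\", \"text\": prompt}]\n",
    "        content += [{\"type\": \"image_url\", \"image_url\": {\"url\": url}} for url in image_urls]\n",
    "    return {\n",
    "        \"messages\": [{\"role\": \"user\", \"content\": content}],\n",
    "        \"max_tokens\": max_tokens,\n",
    "        \"temperature\": 0.6,\n",
    "        \"top_p\": 0.9,\n",
    "    }\n",
    "```\n",
    "\n",
    "Called with a prompt and two image URLs, this returns the same structure as the multi-image payload defined in the next cell. Checking the image count on the client side surfaces a clear error before a request is sent to the endpoint."
   ]
  },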
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2260ced",
   "metadata": {},
   "outputs": [],
   "source": [
    "IMAGE_1_KITTEN = \"https://resources.djl.ai/images/kitten.jpg\"\n",
    "IMAGE_2_TRUCK = \"https://resources.djl.ai/images/truck.jpg\"\n",
    "\n",
    "text_only_payload = {\n",
    "    \"messages\": [\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": \"I would like to get better at basketball. Can you provide me a 3 month plan to improve my skills?\"\n",
    "        }\n",
    "    ],\n",
    "    \"max_tokens\": 1024,\n",
    "    \"temperature\": 0.6,\n",
    "    \"top_p\": 0.9,\n",
    "}\n",
    "\n",
    "single_image_payload = {\n",
    "    \"messages\": [\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": [\n",
    "                {\n",
    "                    \"type\": \"text\",\n",
    "                    \"text\": \"Can you describe the following image and tell me what it contains?\",\n",
    "                },\n",
    "                {\n",
    "                    \"type\": \"image_url\",\n",
    "                    \"image_url\": {\n",
    "                        \"url\": IMAGE_1_KITTEN\n",
    "                    }\n",
    "                }\n",
    "            ]\n",
    "        }\n",
    "    ],\n",
    "    \"max_tokens\": 1024,\n",
    "    \"temperature\": 0.6,\n",
    "    \"top_p\": 0.9,\n",
    "}\n",
    "\n",
    "multi_image_payload = {\n",
    "    \"messages\": [\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": [\n",
    "                {\n",
    "                    \"type\": \"text\",\n",
    "                    \"text\": \"Can you describe the following images and tell me what they have in common? If they have nothing in common, please explain why.\",\n",
    "                },\n",
    "                {\n",
    "                    \"type\": \"image_url\",\n",
    "                    \"image_url\": {\n",
    "                        \"url\": IMAGE_1_KITTEN\n",
    "                    }\n",
    "                },\n",
    "                {\n",
    "                    \"type\": \"image_url\",\n",
    "                    \"image_url\": {\n",
    "                        \"url\": IMAGE_2_TRUCK\n",
    "                    }\n",
    "                }\n",
    "            ]\n",
    "        }\n",
    "    ],\n",
    "    \"max_tokens\": 1024,\n",
    "    \"temperature\": 0.6,\n",
    "    \"top_p\": 0.9,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2485e48f",
   "metadata": {},
   "source": [
    "### Text-only inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2848eca3",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Prompt is:\\n {text_only_payload['messages'][0]['content']}\")\n",
    "text_only_output = predictor.predict(text_only_payload)\n",
    "print(\"Response is:\\n\")\n",
    "print(text_only_output['choices'][0]['message']['content'])\n",
    "print('----------------------------')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12250682",
   "metadata": {},
   "outputs": [],
   "source": [
    "from io import BytesIO\n",
    "\n",
    "import requests\n",
    "from PIL import Image\n",
    "\n",
    "# Download the test images so they can be displayed next to the model's responses\n",
    "response_kitten = requests.get(IMAGE_1_KITTEN, timeout=30)\n",
    "response_kitten.raise_for_status()\n",
    "img_kitten = Image.open(BytesIO(response_kitten.content))\n",
    "response_truck = requests.get(IMAGE_2_TRUCK, timeout=30)\n",
    "response_truck.raise_for_status()\n",
    "img_truck = Image.open(BytesIO(response_truck.content))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb094d99",
   "metadata": {},
   "source": [
    "### Single-image inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5230292",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"This is the image provided to the model\")\n",
    "img_kitten.show()\n",
    "single_image_output = predictor.predict(single_image_payload)\n",
    "print(single_image_output['choices'][0]['message']['content'])\n",
    "print('----------------------------')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5dc171f5",
   "metadata": {},
   "source": [
    "### Multi-image inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6996dd2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"These are the images provided to the model\")\n",
    "img_kitten.show()\n",
    "img_truck.show()\n",
    "multi_image_output = predictor.predict(multi_image_payload)\n",
    "print(multi_image_output['choices'][0]['message']['content'])\n",
    "print('----------------------------')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08f5a68a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean up resources: delete the endpoint and the model once testing is finished\n",
    "predictor.delete_endpoint()\n",
    "model.delete_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
